/*
 * Copyright © 2024-2025 Apple Inc. and the Pkl project authors. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.pkl.core.generator

import com.oracle.truffle.api.dsl.GeneratedBy
import com.palantir.javapoet.ClassName
import com.palantir.javapoet.JavaFile
import com.palantir.javapoet.MethodSpec
import com.palantir.javapoet.TypeSpec
import javax.annotation.processing.AbstractProcessor
import javax.annotation.processing.RoundEnvironment
import javax.lang.model.SourceVersion
import javax.lang.model.element.*
import javax.lang.model.type.TypeMirror

/**
 * Generates a subclass of `org.pkl.core.stdlib.registry.ExternalMemberRegistry` for each stdlib
 * module and a factory to instantiate them. Generated classes are written to
 * `generated/truffle/org/pkl/core/stdlib/registry`.
 *
 * Inputs:
 * - Generated Truffle node classes for stdlib members. These classes are located in subpackages of
 *   `org.pkl.core.stdlib` and identified via their `@GeneratedBy` annotations.
 * - `@PklName` annotations on handwritten node classes from which Truffle node classes are
 *   generated.
 *
 * Outputs (all in package `org.pkl.core.stdlib.registry`):
 * - One `<Module>MemberRegistry` per stdlib package. Its constructor registers each node class's
 *   `create` method under the member's qualified Pkl name: `pkl.<module>#<member>` for module
 *   members, or `pkl.<module>#<Class>.<member>` for class members.
 * - `MemberRegistryFactory`. Its static `get(ModuleKey)` returns the matching registry for a stdlib
 *   module and `EmptyMemberRegistry.INSTANCE` otherwise.
 */
class MemberRegistryGenerator : AbstractProcessor() {
  private val truffleNodeClassSuffix = "NodeGen"
  private val truffleNodeFactorySuffix = "NodesFactory"
  private val stdLibPackageName: String = "org.pkl.core.stdlib"
  private val registryPackageName: String = "$stdLibPackageName.registry"
  private val modulePackageName: String = "org.pkl.core.module"

  private val externalMemberRegistryClassName: ClassName =
    ClassName.get(registryPackageName, "ExternalMemberRegistry")
  private val emptyMemberRegistryClassName: ClassName =
    ClassName.get(registryPackageName, "EmptyMemberRegistry")
  private val memberRegistryFactoryClassName: ClassName =
    ClassName.get(registryPackageName, "MemberRegistryFactory")
  private val moduleKeyClassName: ClassName = ClassName.get(modulePackageName, "ModuleKey")
  private val moduleKeysClassName: ClassName = ClassName.get(modulePackageName, "ModuleKeys")

  override fun getSupportedAnnotationTypes(): Set<String> = setOf(GeneratedBy::class.java.name)

  override fun getSupportedSourceVersion(): SourceVersion = SourceVersion.RELEASE_17

  override fun process(annotations: Set<TypeElement>, roundEnv: RoundEnvironment): Boolean {
    // Nothing to do in rounds that contain no `@GeneratedBy` elements.
    if (annotations.isEmpty()) return true

    // NOTE(review): this assumes all stdlib node classes appear in a single round. If they were
    // spread over several rounds, the factory (and possibly a registry) would be created more than
    // once, which `Filer` rejects with a `FilerException`.
    val nodeClassesByPackage = collectNodeClasses(roundEnv)
    generateRegistryClasses(nodeClassesByPackage)
    generateRegistryFactoryClass(nodeClassesByPackage.keys)

    // Claim `@GeneratedBy`; subsequent processors are not asked to process it.
    return true
  }

  /**
   * Collects the Truffle-generated `*NodeGen` classes located under `org.pkl.core.stdlib`, grouped
   * by package.
   *
   * Classes are sorted by enclosing class name (empty for top-level classes), then by simple name.
   * This makes the registration order independent of the order in which the compiler reports
   * elements. Map keys appear in order of their first class in that sorted sequence.
   */
  private fun collectNodeClasses(
    roundEnv: RoundEnvironment
  ): Map<PackageElement, List<TypeElement>> =
    roundEnv
      .getElementsAnnotatedWith(GeneratedBy::class.java)
      .asSequence()
      .filterIsInstance<TypeElement>()
      .filter { it.qualifiedName.toString().startsWith(stdLibPackageName) }
      .filter { it.simpleName.toString().endsWith(truffleNodeClassSuffix) }
      .sortedWith(
        compareBy(
          {
            if (it.enclosingElement.kind == ElementKind.PACKAGE) ""
            else it.enclosingElement.simpleName.toString()
          },
          { it.simpleName.toString() },
        )
      )
      .groupBy { processingEnv.elementUtils.getPackageOf(it) }

  /** Generates one registry class per entry of [nodeClassesByPackage]. */
  private fun generateRegistryClasses(
    nodeClassesByPackage: Map<PackageElement, List<TypeElement>>
  ) {
    for ((pkg, nodeClasses) in nodeClassesByPackage) {
      generateRegistryClass(pkg, nodeClasses)
    }
  }

  /**
   * Generates `<Module>MemberRegistry` for [pkg]. Its constructor registers every class in
   * [nodeClasses], in the given order, under that class's qualified Pkl member name.
   */
  private fun generateRegistryClass(pkg: PackageElement, nodeClasses: List<TypeElement>) {
    val pklModuleName = pklModuleNameOf(pkg)

    val registryClass =
      TypeSpec.classBuilder(registryClassNameOf(pklModuleName))
        .addJavadoc("Generated by {@link ${this::class.qualifiedName}}.")
        .addModifiers(Modifier.FINAL)
        .superclass(externalMemberRegistryClassName)
    val registryClassConstructor = MethodSpec.constructorBuilder()

    for (nodeClass in nodeClasses) {
      // declare dependency on the node class
      registryClass.addOriginatingElement(nodeClass)
      registryClassConstructor.addStatement(
        "register(\$S, \$T::create)",
        qualifiedPklMemberName(pklModuleName, nodeClass),
        nodeClass,
      )
    }

    registryClass.addMethod(registryClassConstructor.build())
    JavaFile.builder(registryPackageName, registryClass.build())
      .build()
      .writeTo(processingEnv.filer)
  }

  /**
   * Generates `MemberRegistryFactory`. Its static `get(ModuleKey)` returns a new registry for each
   * module in [packages], and `EmptyMemberRegistry.INSTANCE` for non-stdlib or unknown modules.
   */
  private fun generateRegistryFactoryClass(packages: Collection<PackageElement>) {
    val registryFactoryClass =
      TypeSpec.classBuilder(memberRegistryFactoryClassName)
        .addJavadoc("Generated by {@link ${this::class.qualifiedName}}.")
        .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
    val registryFactoryConstructor = MethodSpec.constructorBuilder().addModifiers(Modifier.PRIVATE)
    registryFactoryClass.addMethod(registryFactoryConstructor.build())
    val registryFactoryGetMethod =
      MethodSpec.methodBuilder("get")
        .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
        .addParameter(moduleKeyClassName, "moduleKey")
        .returns(externalMemberRegistryClassName)
        .beginControlFlow("if (!\$T.isStdLibModule(moduleKey))", moduleKeysClassName)
        .addStatement("return \$T.INSTANCE", emptyMemberRegistryClassName)
        .endControlFlow()
        .beginControlFlow("switch (moduleKey.getUri().getSchemeSpecificPart())")

    for (pkg in packages) {
      val pklModuleName = pklModuleNameOf(pkg)

      // declare dependency on package-info.java (for `@PklName`)
      registryFactoryClass.addOriginatingElement(pkg)
      registryFactoryGetMethod
        .addCode("case \$S:\n", pklModuleName)
        .addStatement("  return new \$T()", registryClassNameOf(pklModuleName))
    }

    registryFactoryGetMethod
      .addCode("default:\n")
      .addStatement("  return \$T.INSTANCE", emptyMemberRegistryClassName)
      .endControlFlow()
    registryFactoryClass.addMethod(registryFactoryGetMethod.build())
    JavaFile.builder(registryPackageName, registryFactoryClass.build())
      .build()
      .writeTo(processingEnv.filer)
  }

  /** Returns the Pkl module name for a stdlib package: its `@PklName` if present, else its simple name. */
  private fun pklModuleNameOf(pkg: PackageElement): String =
    getAnnotatedPklName(pkg) ?: pkg.simpleName.toString()

  /**
   * Returns the name of the registry class generated for [pklModuleName]. Shared by the registry
   * and factory generators so both always agree on the name.
   */
  private fun registryClassNameOf(pklModuleName: String): ClassName =
    ClassName.get(registryPackageName, "${pklModuleName.capitalize()}MemberRegistry")

  /**
   * Returns the qualified Pkl name under which [nodeClass] is registered. This is
   * `pkl.<module>#<member>` for module members and `pkl.<module>#<Class>.<member>` for class
   * members. Class and member names come from `@PklName` if present; otherwise they are derived
   * from the Truffle class names by stripping the `NodesFactory` / `NodeGen` suffixes.
   */
  private fun qualifiedPklMemberName(pklModuleName: String, nodeClass: TypeElement): String {
    val enclosingClass = nodeClass.enclosingElement
    val pklClassName =
      getAnnotatedPklName(enclosingClass)
        ?: enclosingClass.simpleName.toString().removeSuffix(truffleNodeFactorySuffix)
    val pklMemberName =
      getAnnotatedPklName(nodeClass)
        ?: nodeClass.simpleName.toString().removeSuffix(truffleNodeClassSuffix)
    return when (pklClassName) {
      // By convention, the top-level class containing node classes
      // for *module* members is named `<SimpleModuleName>Nodes`.
      // Example: `BaseNodes` for pkl.base
      pklModuleName.capitalize() -> "pkl.$pklModuleName#$pklMemberName"
      else -> "pkl.$pklModuleName#$pklClassName.$pklMemberName"
    }
  }

  /**
   * Returns the Pkl name declared for [element], or `null` if none is declared.
   *
   * A `@PklName` annotation on [element] supplies the name directly. For a Truffle-generated
   * element, `@GeneratedBy` is followed to the handwritten class it was generated from, and that
   * class's `@PklName` (if any) is used. Annotations are matched by simple name only.
   */
  private fun getAnnotatedPklName(element: Element): String? {
    for (annotation in element.annotationMirrors) {
      when (annotation.annotationType.asElement().simpleName.toString()) {
        "PklName" -> return annotation.elementValues.values.first().value.toString()
        "GeneratedBy" -> {
          // Look up `value` by name. `@GeneratedBy` may also carry a `methodName` element, so the
          // first explicitly-set element is not necessarily the source class.
          val valueEntry =
            annotation.elementValues.entries.firstOrNull { (key, _) ->
              key.simpleName.contentEquals("value")
            }
          val sourceType = valueEntry?.value?.value as? TypeMirror ?: return null
          val sourceElement = processingEnv.typeUtils.asElement(sourceType) ?: return null
          return getAnnotatedPklName(sourceElement)
        }
      }
    }
    return null
  }

  /** Upper-cases the first character (title case), leaving the rest unchanged. */
  private fun String.capitalize(): String = replaceFirstChar { it.titlecaseChar() }
}
